from pymilvus import connections, FieldSchema, CollectionSchema, DataType, Collection, utility


# Create a collection (table) in the Milvus vector database.
def create_milvus_collection(collection_name, dim, metric_type='COSINE', host='127.0.0.1', *,
                             port='19530', nlist=128):
    """Return the Milvus collection ``collection_name``, creating it if absent.

    Connects to the Milvus server at ``host``:``port``. If a collection with
    the given name already exists it is returned as-is: its schema and
    indexes are NOT checked against ``dim`` / ``metric_type`` / ``nlist``.
    Otherwise a new collection is created with the fields ``id`` (VARCHAR
    primary key, not auto-generated), ``vector`` (FLOAT_VECTOR), ``user_id``
    and ``image_path`` (both VARCHAR), followed by an IVF_FLAT index on
    ``vector`` and a ``marisa-trie`` index on ``user_id``.

    Args:
        collection_name: Name of the collection to open or create.
        dim: Dimension given to the ``vector`` FLOAT_VECTOR field.
        metric_type: Metric type stored in the vector index parameters.
        host: Host passed to ``connections.connect``.
        port: Port passed to ``connections.connect`` (keyword-only;
            defaults to the previously hard-coded ``'19530'``).
        nlist: ``nlist`` parameter of the IVF_FLAT index (keyword-only;
            defaults to the previously hard-coded ``128``).

    Returns:
        The ``Collection`` object for ``collection_name``.
    """
    connections.connect(host=host, port=port)

    # Reuse an existing collection untouched rather than recreating it.
    if utility.has_collection(collection_name):
        return Collection(name=collection_name)

    fields = [
        FieldSchema(name='id', dtype=DataType.VARCHAR, description='ids', is_primary=True, max_length=128,
                    auto_id=False),
        FieldSchema(name='vector', dtype=DataType.FLOAT_VECTOR, description='embedding vectors', dim=dim),
        FieldSchema(name='user_id', dtype=DataType.VARCHAR, description='user id', max_length=128),
        FieldSchema(name='image_path', dtype=DataType.VARCHAR, description='image path', max_length=512),
    ]
    schema = CollectionSchema(fields=fields, description='multimodal search')
    milvus_collection = Collection(name=collection_name, schema=schema)

    # Create the IVF_FLAT index on the embedding field.
    vector_index_params = {
        'metric_type': metric_type,
        'index_type': 'IVF_FLAT',
        'params': {"nlist": nlist},
    }
    milvus_collection.create_index(field_name="vector", index_params=vector_index_params,
                                   index_name='idx_vector')

    # Create the scalar index on user_id.
    user_id_index_params = {'index_type': "marisa-trie"}
    milvus_collection.create_index(field_name='user_id', index_params=user_id_index_params,
                                   index_name='idx_user_id')

    return milvus_collection


if __name__ == "__main__":
    # Script entry point: open (or create) one collection and print it.
    # Collection names noted here: multimodal_search_H14_L2, multimodal_search_H14, multimodal_search_L14
    # Metric types noted here: 'L2', 'COSINE'
    settings = {
        'collection_name': 'multimodal_search_H14',
        'dim': 1024,
        'metric_type': 'COSINE',
        'host': 'localhost',
    }
    print(create_milvus_collection(**settings))
